
import torch
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from torch import nn
import torch.nn.functional  as F
import matplotlib.pyplot as plt
import matplotlib


class CNN(nn.Module):
    """Two-layer network over a pair of patches with signed output heads.

    The first layer ``W0`` (shape ``(d, m)``) is shared by both patches. Each
    hidden representation is mapped through a "positive" head ``Wp`` and a
    "negative" head ``Wn`` (both ``(m, m)``), activated, and averaged over the
    ``m`` head units. The output is the positive-head total minus the
    negative-head total, each summed over the two patches.

    Args:
        m: hidden width.
        d: patch (input) dimension.
        q: power applied after the ReLU, i.e. ``act(x) = relu(x) ** q``.
        linear: when True the activation is the identity and ``q`` is unused.
    """

    def __init__(self, m=50, d=1000, q=1, linear=False):
        super().__init__()
        self.q = q
        self.linear = linear

        # Random draws happen in the order W0, Wp, Wn (and W0 is re-drawn
        # below); keep this order so seeded runs reproduce the same weights.
        # nn.Parameter already sets requires_grad=True.
        self.W0 = nn.Parameter(torch.randn(d, m))
        self.Wp = nn.Parameter(torch.randn(m, m))
        self.Wn = nn.Parameter(torch.randn(m, m))

        # Small-scale first layer; the two heads keep their N(0, 1) draw.
        nn.init.normal_(self.W0, std=0.001)

    def act(self, input):
        """Return ``input`` unchanged when linear, else ``relu(input) ** q``."""
        if not self.linear:
            return torch.pow(F.relu(input), self.q)
        return input

    def _head_mean(self, hidden, head):
        """Activate ``hidden @ head`` and average over the head units (dim 1)."""
        return self.act(torch.mm(hidden, head)).mean(dim=1)

    def forward(self, x1, x2):
        h1 = self.act(torch.mm(x1, self.W0))
        h2 = self.act(torch.mm(x2, self.W0))

        positive = self._head_mean(h1, self.Wp) + self._head_mean(h2, self.Wp)
        negative = self._head_mean(h1, self.Wn) + self._head_mean(h2, self.Wn)
        return positive - negative


def prepare_data(n_train=None, n_test=None, d=None, noise_std=1.0):
    """Build the synthetic two-patch binary classification dataset.

    Each example has two patches:

    * ``x1`` is the signal patch ``y * e_1`` (label times the first standard
      basis vector of R^d);
    * ``x2`` is a Gaussian noise patch with i.i.d. ``N(0, noise_std**2)``
      entries.

    Labels are +1 for the first ``n // 2`` examples of each split and -1 for
    the rest. When a split size is odd, the extra example gets label -1, so
    the label vector always has exactly as many rows as the noise patch.

    Args:
        n_train: number of training examples; defaults to the module-level
            ``n_train`` so existing ``prepare_data()`` calls keep working.
        n_test: number of test examples (module-level ``n_test`` by default).
        d: patch dimension (module-level ``d`` by default).
        noise_std: scale of the noise patches (1.0 reproduces the original).

    Returns:
        ``(train_x1, train_x2, train_y, test_x1, test_x2, test_y)``. The
        ``*_x1`` / ``*_x2`` tensors have shape ``(n, d)``; the ``*_y`` tensors
        have shape ``(n,)``.
    """
    module_globals = globals()
    if n_train is None:
        n_train = module_globals["n_train"]
    if n_test is None:
        n_test = module_globals["n_test"]
    if d is None:
        d = module_globals["d"]

    def balanced_labels(n):
        # Integer split (not int(n / 2)) so the result has exactly n entries
        # even for odd n; for even n this matches the original halves.
        n_pos = n // 2
        return torch.cat((torch.ones(n_pos), -torch.ones(n - n_pos)))

    train_y = balanced_labels(n_train)
    test_y = balanced_labels(n_test)

    feature = torch.zeros(d)
    feature[0] = 1.0
    # Outer product via broadcasting: row i is y_i * e_1.
    train_x1 = train_y.unsqueeze(1) * feature
    test_x1 = test_y.unsqueeze(1) * feature
    # Drawn train-then-test, matching the original RNG consumption order.
    train_x2 = torch.randn(n_train, d) * noise_std
    test_x2 = torch.randn(n_test, d) * noise_std

    return train_x1, train_x2, train_y, test_x1, test_x2, test_y

# seed = 3407
seed = 2023
n_train = 200
n_test = 2000
d = 2000
n_epoch = 5000

np.random.seed(seed)
torch.manual_seed(seed)

train_x1, train_x2, train_y, test_x1, test_x2, test_y = prepare_data()

# No longer used to weight the epoch loss (the real batch size is used
# instead), but kept because other parts of the script may still read it.
sample_size = n_train
# With n_train = 200, a batch size of 250 yields one full batch per epoch, so
# each epoch below is a single full-gradient step.
data_loader = DataLoader(TensorDataset(
    train_x1,
    train_x2,
    train_y
), batch_size=250, shuffle=True)


width = 20
learning_rate = 0.5

# ---- Run 1: gradient descent with random label flips ("Label Noise GD") ----
model = CNN(m=width, d=d)

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Per-epoch diagnostics; stored detached so no autograd graph is retained.
train_f_preds = []
loss_derivatives = []

train_loss_values = []
test_loss_values = []
train_acc_values = []

test_acc_values = []
feature_learning = []
# Projections of each first-layer filter (column of W0) onto every training
# noise patch, shape (width, n_train, n_epoch), and onto the signal patch of
# training example 0, shape (width, n_epoch).
noise_memorization_p = np.zeros((width, n_train, n_epoch))
feature_learning_p = np.zeros((width, n_epoch))

noise_memorization_n = np.zeros((width, n_train, n_epoch))
feature_learning_n = np.zeros((width, n_epoch))

for ep in range(n_epoch):
    train_loss = 0
    loss_d = []

    model.train()
    for sample_x1, sample_x2, sample_y in data_loader:
        optimizer.zero_grad()
        f_pred = model(sample_x1, sample_x2)
        # +1 with probability 0.9, -1 otherwise: flips ~10% of the labels.
        noise = 2*torch.bernoulli(torch.ones_like(sample_y)*0.9)-1
        margin = f_pred * sample_y * noise
        # softplus(-margin) == log(1 + exp(-margin)) but does not overflow to
        # inf for large negative margins.
        loss = F.softplus(-margin).mean()
        # sigmoid(-margin) == 1 / (exp(margin) + 1); detached so the list does
        # not keep this batch's autograd graph alive for the whole run.
        loss_d.append(torch.sigmoid(-margin).detach())

        loss.backward()
        optimizer.step()
        # Weight by the actual batch size so the epoch mean stays correct if
        # the loader yields several batches or a ragged last batch.
        train_loss += sample_y.shape[0] * loss.item()
    model.eval()

    # Evaluation only: no gradients are needed, and nothing stored below may
    # hold on to an autograd graph.
    with torch.no_grad():
        # The model has a single first layer, so the "p" and "n" records are
        # the same projection; compute it once and store it in both.
        feature_proj = torch.matmul(model.W0.T, train_x1[0]).numpy()
        noise_proj = torch.matmul(model.W0.T, train_x2.T).numpy()
        feature_learning_p[:, ep] = feature_proj
        noise_memorization_p[:, :, ep] = noise_proj
        feature_learning_n[:, ep] = feature_proj
        noise_memorization_n[:, :, ep] = noise_proj

        f_pred_test = model(test_x1, test_x2)
        f_pred_train = model(train_x1, train_x2)

        pred_binary_train = (f_pred_train > 0).float() * 2 - 1
        pred_binary_test = (f_pred_test > 0).float() * 2 - 1

        correct_preds_train = (pred_binary_train == train_y).float().mean()
        correct_preds_test = (pred_binary_test == test_y).float().mean()

        # The test loss is measured against the clean (unflipped) labels.
        test_loss = F.softplus(-f_pred_test * test_y).mean()

    train_loss /= n_train
    train_loss_values.append(train_loss)

    train_f_preds.append(f_pred_train)
    loss_derivatives.append(torch.cat(loss_d))

    train_acc_values.append(correct_preds_train.item())
    test_acc_values.append(correct_preds_test.item())
    test_loss_values.append(test_loss.item())

    print(f'[{ep+1}|{n_epoch}] train_loss={train_loss:0.5e}, test_loss={test_loss:0.5e}')


# ---- Run 2: plain gradient descent on the clean labels ("GD") ----
# A fresh model is drawn from the RNG state left by the first run.
model = CNN(m=width, d=d)

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

train_loss_values_sgd = []
test_loss_values_sgd = []
train_acc_values_sgd = []
test_acc_values_sgd = []

# Same layout as the first run: filter projections onto the training noise
# patches (width, n_train, n_epoch) and onto the signal patch of training
# example 0 (width, n_epoch).
noise_memorization_p_sgd = np.zeros((width, n_train, n_epoch))
feature_learning_p_sgd = np.zeros((width, n_epoch))

noise_memorization_n_sgd = np.zeros((width, n_train, n_epoch))
feature_learning_n_sgd = np.zeros((width, n_epoch))

for ep in range(n_epoch):
    train_loss = 0

    model.train()
    for sample_x1, sample_x2, sample_y in data_loader:
        optimizer.zero_grad()
        f_pred = model(sample_x1, sample_x2)
        # softplus(-z) == log(1 + exp(-z)) without overflow for large |z|.
        loss = F.softplus(-f_pred * sample_y).mean()

        loss.backward()
        optimizer.step()
        # Weight by the actual batch size so the epoch mean is correct for
        # any number of batches.
        train_loss += sample_y.shape[0] * loss.item()
    model.eval()

    # Evaluation only: skip autograd bookkeeping entirely.
    with torch.no_grad():
        # Single first layer: the "p" and "n" records are the same projection.
        feature_proj = torch.matmul(model.W0.T, train_x1[0]).numpy()
        noise_proj = torch.matmul(model.W0.T, train_x2.T).numpy()
        feature_learning_p_sgd[:, ep] = feature_proj
        noise_memorization_p_sgd[:, :, ep] = noise_proj
        feature_learning_n_sgd[:, ep] = feature_proj
        noise_memorization_n_sgd[:, :, ep] = noise_proj

        f_pred_test = model(test_x1, test_x2)
        f_pred_train = model(train_x1, train_x2)

        pred_binary_train = (f_pred_train > 0).float() * 2 - 1
        pred_binary_test = (f_pred_test > 0).float() * 2 - 1

        correct_preds_train = (pred_binary_train == train_y).float().mean()
        correct_preds_test = (pred_binary_test == test_y).float().mean()

        test_loss = F.softplus(-f_pred_test * test_y).mean()

    train_loss /= n_train
    train_loss_values_sgd.append(train_loss)

    train_acc_values_sgd.append(correct_preds_train.item())
    test_acc_values_sgd.append(correct_preds_test.item())
    test_loss_values_sgd.append(test_loss.item())

    print(f'[{ep+1}|{n_epoch}] train_loss={train_loss:0.5e}, test_loss={test_loss:0.5e}')



# Type 42 embeds TrueType fonts in PDF/PS output.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
matplotlib.rcParams['font.size'] = 16

# One row of four panels: coefficient dynamics, their ratio, train loss and
# test accuracy. Each compares label-noise GD (blue) with plain GD (red).
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(25, 4.5))

# Summarise each coefficient history as the max over the neuron axis (axis 0)
# of its absolute value, then take the elementwise max of the "p" and "n"
# records. noise_*: (n_train, n_epoch); feature_*: (n_epoch,).
noise_pseris = np.max(np.abs(noise_memorization_p), axis=0)
noise_nseris = np.max(np.abs(noise_memorization_n), axis=0)
noise_mseris = np.maximum(noise_pseris, noise_nseris)

feature_pseris = np.max(np.abs(feature_learning_p), axis=0)
feature_nseris = np.max(np.abs(feature_learning_n), axis=0)
feature_mseris = np.maximum(feature_pseris, feature_nseris)


noise_pseris_sgd = np.max(np.abs(noise_memorization_p_sgd), axis=0)
noise_nseris_sgd = np.max(np.abs(noise_memorization_n_sgd), axis=0)
noise_mseris_sgd = np.maximum(noise_pseris_sgd, noise_nseris_sgd)

feature_pseris_sgd = np.max(np.abs(feature_learning_p_sgd), axis=0)
feature_nseris_sgd = np.max(np.abs(feature_learning_n_sgd), axis=0)
feature_mseris_sgd = np.maximum(feature_pseris_sgd, feature_nseris_sgd)

# Panel 1: noise memorisation of training example i = 0 (dashed) versus
# feature learning (solid) over training.
ax1.plot(noise_mseris_sgd[0], color='tab:red', linestyle='--', linewidth=2, label=r'$\max_{j,r} \rho_{j,r,i}$ (GD)')
ax1.plot(feature_mseris_sgd, linewidth=2, label=r'$\max_{j,r} \gamma_{j,r}$ (GD)', color='tab:red')
ax1.plot(noise_mseris[0], linewidth=2, color='tab:blue', linestyle='--', label=r'$\max_{j,r} \rho_{j,r,i}$ (Label Noise GD)')
ax1.plot(feature_mseris, linewidth=2, label=r'$\max_{j,r} \gamma_{j,r}$ (Label Noise GD)', color='tab:blue')

# ax1.set_yscale('log')
# ax1.set_xscale('log')

ax1.set_title('Feature Learning')
ax1.set_xlabel('t', fontsize=20)
ax1.tick_params(axis='both', which='major', labelsize=20)
ax1.legend()

# Panel 2: ratio of noise memorisation to feature learning (rho / gamma).
ax2.plot(noise_mseris[0] / feature_mseris, linewidth=2, color='tab:blue', label='Label Noise GD')
ax2.plot(noise_mseris_sgd[0] / feature_mseris_sgd, linewidth=2, color='tab:red', label='GD')
ax2.set_title(r'Feature Learning Ratio ($\rho/\gamma$)')
ax2.set_xlabel('t', fontsize=20)
ax2.tick_params(axis='both', which='major', labelsize=20)
ax2.legend()

# Panel 3: training loss per epoch.
ax3.plot(train_loss_values, linewidth=2, label='Label Noise GD', color='tab:blue')
ax3.plot(train_loss_values_sgd, linewidth=2, label='GD', color='tab:red')
ax3.set_title('Train Loss')
ax3.set_xlabel('t', fontsize=20)
ax3.legend()
ax3.tick_params(axis='both', which='major', labelsize=20)

# Panel 4: test accuracy per epoch.
ax4.plot(test_acc_values, linewidth=2, label='Label Noise GD', color='tab:blue')
ax4.plot(test_acc_values_sgd, linewidth=2, label='GD', color='tab:red')
ax4.set_title('Test Accuracy')
ax4.set_xlabel('t', fontsize=20)
ax4.legend()
ax4.tick_params(axis='both', which='major', labelsize=20)

fig.savefig('two_layer_results.png', dpi=300, bbox_inches='tight')

plt.show()